from threeML import *
from powerlaw_synchrotron import Synchrotron_PowerLaw
from threeML.io.file_utils import file_existing_and_readable
import yaml
import re
import warnings
import copy
from read_grb import read_grb
from glob import glob


# Grab the GRB names from the folders under grbs/.
# The list is sorted because the model getters below select a GRB by
# integer ``id`` into this list, and glob() returns entries in an
# arbitrary, filesystem-dependent order; sorting keeps id -> GRB stable
# across runs and processes. basename() is used instead of split('/')
# so the names are also correct with non-POSIX path separators.
import os

grbs_to_fit = sorted(os.path.basename(f) for f in glob('grbs/GRB*'))



# The following code sets up the models and the grid-minimizer
# callbacks used for each spectral function. They can be fed to
# any Bayesian sampler or have their minimizers changed.


def model_getter_band(id):
    """
    Build the likelihood model for GRB number ``id`` with a Band spectrum.

    Parameters
    ----------
    id : int
        Index into ``grbs_to_fit``; the selected GRB name becomes the
        point-source name.

    Returns
    -------
    Model
        A model holding a single ``PointSource`` (placed at coordinates
        0, 0) whose spectral shape is a configured ``Band`` function.
    """
    spectrum = Band()

    # Restrict the low-energy index range before its prior is derived.
    spectrum.alpha.max_value = 2
    spectrum.alpha.min_value = -1.7

    # Uninformative (uniform) priors on both spectral indices.
    for index_param in (spectrum.alpha, spectrum.beta):
        index_param.set_uninformative_prior(Uniform_prior)

    # Log-uniform priors given as (parameter, lower bound, upper bound).
    for param, low, high in ((spectrum.K, 1E-5, 1E7),
                             (spectrum.xp, 1E1, 1E4)):
        param.prior = Log_uniform_prior(lower_bound=low, upper_bound=high)

    source = PointSource(grbs_to_fit[id], 0, 0, spectral_shape=spectrum)
    return Model(source)


def my_callback_band(minimizer, likelihood_model):
    """
    Configure a grid minimizer for a fit with a Band spectral shape.

    At every grid point the ROOT minimizer is used. The grid spans the Band
    normalization ``K`` (6 log-spaced values in [1e-4, 1e3]) and ``xp``
    (3 log-spaced values in [1e1, 1e5]) of the model's first point source.

    Parameters
    ----------
    minimizer
        The grid minimizer being configured; ``set_minimizer`` and
        ``add_parameter_to_grid`` are called on it.
    likelihood_model
        Model whose first point source has a ``Band`` spectrum.
    """
    # Local import so the callback does not depend on ``np`` leaking in
    # through ``from threeML import *`` (matches my_callback_synchrotron).
    import numpy as np

    # For each point in the grid the ROOT minimizer will be used.
    minimizer.set_minimizer("ROOT")

    # ``point_sources.keys()[0]`` only works on Python 2: dict key views are
    # not subscriptable on Python 3. Take the first key portably.
    name = next(iter(likelihood_model.point_sources))
    band = likelihood_model.point_sources[name].spectrum.main.Band

    minimizer.add_parameter_to_grid(band.K, np.logspace(-4, 3, 6))
    minimizer.add_parameter_to_grid(band.xp, np.logspace(1, 5, 3))




def model_getter_sbpl(id):
    """
    Build the likelihood model for GRB number ``id`` with a smoothly
    broken power-law spectrum.

    Parameters
    ----------
    id : int
        Index into ``grbs_to_fit``; the selected GRB name becomes the
        point-source name.

    Returns
    -------
    Model
        A model holding a single ``PointSource`` (placed at coordinates
        0, 0) whose spectral shape is a configured
        ``SmoothlyBrokenPowerLaw``.
    """
    shape = SmoothlyBrokenPowerLaw()

    # Starting values and bounds, applied in the same order as always
    # (values and bounds interact, so the sequence is kept).
    shape.K = 0.07
    shape.break_energy.max_value = 1E5
    shape.break_scale.free = True
    shape.break_scale.max_value = 1E4
    shape.pivot = 300.
    shape.alpha.min_value = -2.
    shape.beta.min_value = -5

    # Uninformative (uniform) priors on both spectral indices.
    for index_param in (shape.alpha, shape.beta):
        index_param.set_uninformative_prior(Uniform_prior)

    # Log-uniform priors given as (parameter, lower bound, upper bound).
    log_uniform_limits = (
        (shape.K, 1E-4, 1E7),
        (shape.break_energy, 1E1, 1E4),
        (shape.break_scale, 1E-5, 1E4),
    )
    for param, low, high in log_uniform_limits:
        param.prior = Log_uniform_prior(lower_bound=low, upper_bound=high)

    return Model(PointSource(grbs_to_fit[id], 0, 0, spectral_shape=shape))


def my_callback_sbpl(minimizer, likelihood_model):
    """
    Configure a grid minimizer for a fit with a smoothly broken power law.

    At every grid point the ROOT minimizer is used. The grid spans
    ``break_scale`` (3 log-spaced values in [1e-5, 1e2]) and the
    normalization ``K`` (6 log-spaced values in [1e-4, 1e7]) of the model's
    first point source.

    Parameters
    ----------
    minimizer
        The grid minimizer being configured; ``set_minimizer`` and
        ``add_parameter_to_grid`` are called on it.
    likelihood_model
        Model whose first point source has a ``SmoothlyBrokenPowerLaw``
        spectrum.
    """
    # Local import so the callback does not depend on ``np`` leaking in
    # through ``from threeML import *`` (matches my_callback_synchrotron).
    import numpy as np

    # For each point in the grid the ROOT minimizer will be used.
    minimizer.set_minimizer("ROOT")

    # ``point_sources.keys()[0]`` only works on Python 2: dict key views are
    # not subscriptable on Python 3. Take the first key portably.
    name = next(iter(likelihood_model.point_sources))
    sbpl = likelihood_model.point_sources[name].spectrum.main.SmoothlyBrokenPowerLaw

    minimizer.add_parameter_to_grid(sbpl.break_scale, np.logspace(-5, 2, 3))
    minimizer.add_parameter_to_grid(sbpl.K, np.logspace(-4, 7, 6))





def model_getter_synchrotron(id):
    """
    Build the likelihood model for GRB number ``id`` with a
    ``Synchrotron_PowerLaw`` spectrum.

    Parameters
    ----------
    id : int
        Index into ``grbs_to_fit``; the selected GRB name becomes the
        point-source name.

    Returns
    -------
    Model
        A model holding a single ``PointSource`` (placed at coordinates
        0, 0) whose spectral shape is a configured ``Synchrotron_PowerLaw``.
    """
    spl = Synchrotron_PowerLaw()

    # ``index`` is bounded to [2.2, 7] and started at 5.
    spl.index.min_value = 2.2
    spl.index.value = 5.
    # Previously misspelled ``mex_value``, which did not set the upper
    # bound, so the uniform prior below did not get the intended [2.2, 7].
    spl.index.max_value = 7.
    spl.gamma_max = 1E7

    spl.B.prior = Log_uniform_prior(lower_bound=1E0, upper_bound=1E5)
    spl.K.prior = Log_uniform_prior(lower_bound=1E-3, upper_bound=1E5)
    # Uniform prior over the index's [min_value, max_value] range.
    spl.index.set_uninformative_prior(Uniform_prior)

    ps = PointSource(grbs_to_fit[id], 0, 0, spectral_shape=spl)
    return Model(ps)


def my_callback_synchrotron(minimizer, likelihood_model):
    """
    Configure a grid minimizer for a fit with a ``Synchrotron_PowerLaw``.

    At every grid point the ROOT minimizer is used. The grid spans ``B``
    (3 log-spaced values in [1e-5, 1]) and the normalization ``K``
    (5 log-spaced values in [1e-5, 1e4]) of the model's first point source.

    Parameters
    ----------
    minimizer
        The grid minimizer being configured; ``set_minimizer`` and
        ``add_parameter_to_grid`` are called on it.
    likelihood_model
        Model whose first point source has a ``Synchrotron_PowerLaw``
        spectrum.
    """
    import numpy as np

    # For each point in the grid the ROOT minimizer will be used.
    minimizer.set_minimizer("ROOT")

    # ``point_sources.keys()[0]`` only works on Python 2: dict key views are
    # not subscriptable on Python 3. Take the first key portably.
    name = next(iter(likelihood_model.point_sources))
    synchrotron = likelihood_model.point_sources[name].spectrum.main.Synchrotron_PowerLaw

    # NOTE(review): these B starting values (1e-5 .. 1) lie at or below the
    # lower bound (1E0) of the B prior set in model_getter_synchrotron —
    # confirm this range is intended.
    minimizer.add_parameter_to_grid(synchrotron.B, np.logspace(-5, 0, 3))
    minimizer.add_parameter_to_grid(synchrotron.K, np.logspace(-5, 4, 5))




